{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "changeing train\nepoch 1,loss 0.0031,train acc 0.707,test_acc 0.731\nepoch 2,loss 0.0019,train acc 0.823,test_acc 0.801\nepoch 3,loss 0.0017,train acc 0.844,test_acc 0.839\nepoch 4,loss 0.0015,train acc 0.856,test_acc 0.839\nepoch 5,loss 0.0015,train acc 0.861,test_acc 0.850\n"
    }
   ],
   "source": [
    "import numpy as np \n",
    "import torch\n",
    "import torchvision\n",
    "import d2lzh as d2l\n",
    "\n",
    "# Dataset: train/test mini-batch iterators\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "\n",
    "# Layer widths: input, output and hidden\n",
    "num_inputs, num_outputs, num_hiddens = 784, 10, 256\n",
    "\n",
    "# Weights are sampled with NumPy from a normal distribution (mean 0, std 0.01);\n",
    "# biases start at zero. Everything is stored as float32.\n",
    "W1 = torch.tensor(\n",
    "    np.random.normal(0, 0.01, (num_inputs, num_hiddens)), dtype=torch.float32\n",
    ")\n",
    "b1 = torch.zeros(num_hiddens, dtype=torch.float32)\n",
    "\n",
    "W2 = torch.tensor(\n",
    "    np.random.normal(0, 0.01, (num_hiddens, num_outputs)), dtype=torch.float32\n",
    ")\n",
    "b2 = torch.zeros(num_outputs, dtype=torch.float32)\n",
    "\n",
    "# Turn on gradient tracking for every trainable tensor.\n",
    "params = [W1, b1, W2, b2]\n",
    "for param in params:\n",
    "    param.requires_grad_()\n",
    "\n",
    "def relu(X):\n",
    "    \"\"\"Return the element-wise maximum of X and zero (ReLU).\"\"\"\n",
    "    zero = torch.tensor(0.0)\n",
    "    return torch.max(X, zero)\n",
    "\n",
    "# Model\n",
    "def net(X):\n",
    "    \"\"\"Forward pass: flatten X, apply the ReLU hidden layer, then the output layer.\"\"\"\n",
    "    flat = X.view(-1, num_inputs)\n",
    "    hidden = relu(flat @ W1 + b1)\n",
    "    return hidden @ W2 + b2\n",
    "\n",
    "# Loss function\n",
    "loss = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Training\n",
    "# NOTE(review): lr = 100.0 is unusually large for plain SGD; presumably the\n",
    "# update inside d2l.train_ch3 rescales the step (e.g. by batch_size) -- confirm.\n",
    "num_epochs, lr = 5, 100.0\n",
    "d2l.train_ch3(\n",
    "    net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "changeing train\nepoch 1,loss 0.0031,train acc 0.703,test_acc 0.712\nepoch 2,loss 0.0019,train acc 0.823,test_acc 0.782\nepoch 3,loss 0.0017,train acc 0.842,test_acc 0.841\nepoch 4,loss 0.0015,train acc 0.856,test_acc 0.843\nepoch 5,loss 0.0014,train acc 0.863,test_acc 0.834\n"
    }
   ],
   "source": [
    "# 简洁实现\n",
    "import torch.nn as nn \n",
    "\n",
    "class FlattenLayer(torch.nn.Module):\n",
    "    \"\"\"Module that reshapes its input to 2-D, keeping the first dimension.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Return a view of X with shape (X.shape[0], -1).\"\"\"\n",
    "        leading = X.size(0)\n",
    "        return X.view(leading, -1)\n",
    "\n",
    "\n",
    "# Sequential model: flatten -> linear -> ReLU -> linear\n",
    "net = torch.nn.Sequential(\n",
    "    FlattenLayer(),\n",
    "    nn.Linear(num_inputs, num_hiddens),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(num_hiddens, num_outputs),\n",
    ")\n",
    "\n",
    "# Initialise every parameter (weights and biases alike) from N(0, 0.01^2).\n",
    "# The loop variable is named `param`, not `params`, so the `params` list built\n",
    "# for the from-scratch model is not overwritten with a single tensor.\n",
    "for param in net.parameters():\n",
    "    nn.init.normal_(param, mean=0, std=0.01)\n",
    "\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.5)\n",
    "\n",
    "# An optimizer is supplied, so the params and lr arguments are passed as None.\n",
    "d2l.train_ch3(\n",
    "    net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37764bitpytorchnotebookconda6e7a8693df0d4d92aca09d521275d23a",
   "display_name": "Python 3.7.7 64-bit ('pytorch_notebook': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}